"""
Phase 8: Pseudo-OOS Backtest
=============================
Estimate the rating model on 1990-2010 data, project ratings to 2024 using the
actual 2024 demographics, and check how often the Monte Carlo predictive
intervals cover the actual 2024 ratings.

Output: table12_oos_backtest.md
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Directory layout. NOTE(review): PROJECT_DIR is a hard-coded absolute path,
# so the script only runs on a machine where that path exists.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
ROOT_DIR = PROJECT_DIR.parent
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"  # input: cliff_panel.csv
TABLES_DIR = PROJECT_DIR / "output" / "tables"  # output: table12_oos_backtest.md

# Make the sibling "multilateral" project's src/ directory importable. This
# must run before the PanelGLS import below, which is resolved from that path.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Minimum numeric rating counted as a "safe" issuer (>= this value is safe).
SAFE_THRESHOLD = 18  # AA-

# Countries that were safe issuers as of 2010, mapped to their 2010 numeric
# rating. Only countries present in BOTH this dict and ACTUAL_2024 are
# backtested (main() takes the key intersection).
# NOTE(review): confirm the rating agency and as-of date for these values —
# e.g. ESP/IRL/ITA look high for S&P ratings in 2010.
SAFE_2010 = {
    'USA': 21, 'DEU': 21, 'GBR': 21, 'FRA': 21, 'CAN': 21, 'AUS': 21,
    'CHE': 21, 'NLD': 21, 'AUT': 21, 'DNK': 21, 'FIN': 21, 'NOR': 21,
    'SWE': 21, 'SGP': 21, 'HKG': 21, 'LUX': 21, 'NZL': 21, 'BEL': 20,
    'KWT': 19, 'QAT': 19, 'ESP': 21, 'IRL': 21, 'ITA': 20,
}

# Actual 2024 ratings (S&P 21-point scale)
# KOR, TWN, CZE and ARE are not in SAFE_2010, so the intersection taken in
# main() drops them from the backtest.
# NOTE(review): confirm the as-of date (start vs end of 2024) for these values.
ACTUAL_2024 = {
    'USA': 20, 'DEU': 21, 'GBR': 19, 'FRA': 19, 'CAN': 21, 'AUS': 21,
    'CHE': 21, 'NLD': 21, 'AUT': 20, 'DNK': 21, 'FIN': 20, 'NOR': 21,
    'SWE': 21, 'SGP': 21, 'HKG': 20, 'LUX': 21, 'NZL': 20, 'BEL': 19,
    'KWT': 19, 'QAT': 19, 'ESP': 15, 'IRL': 19, 'ITA': 14,
    'KOR': 19, 'TWN': 20, 'CZE': 19, 'ARE': 19,
}


def _ols_vcv(y, X, beta_full):
    """Return ``(vcv, sigma2)`` for the coefficient vector ``[constant, *slopes]``.

    Uses the homoskedastic OLS formula ``sigma2 * inv(X1'X1)``, where ``X1`` is
    ``X`` with a leading column of ones and ``sigma2`` is the residual variance
    of ``y - X1 @ beta_full`` with an ``n - k`` degrees-of-freedom correction.

    NOTE(review): ``beta_full`` comes from PanelGLS while this covariance is the
    plain OLS formula, so any error structure PanelGLS models is not reflected
    in the Monte Carlo spread — confirm this approximation is intended.

    Raises:
        ValueError: if there are no residual degrees of freedom.
    """
    # Explicit column of ones: unlike statsmodels' add_constant, this never
    # silently skips the intercept when some regressor happens to be constant.
    X1 = np.column_stack([np.ones(len(X)), X])
    n_obs, k = X1.shape
    if n_obs <= k:
        raise ValueError(f"Need more than {k} observations to estimate the VCV, got {n_obs}")
    resid = y - X1 @ beta_full
    sigma2 = float(resid @ resid) / (n_obs - k)
    return sigma2 * np.linalg.inv(X1.T @ X1), sigma2


def _draw_coefficients(beta_full, vcv, n_sim):
    """Draw ``n_sim`` coefficient vectors (shape ``(n_sim, len(beta_full))``).

    Samples from N(beta_full, vcv) via the global NumPy RNG. If that fails with
    a LinAlgError, falls back to independent normals using the diagonal s.e.
    """
    try:
        return np.random.multivariate_normal(beta_full, vcv, size=n_sim)
    except np.linalg.LinAlgError:
        se = np.sqrt(np.abs(np.diag(vcv)))
        return np.column_stack([np.random.normal(b, s, n_sim) for b, s in zip(beta_full, se)])


def _latest_old_dep(df, iso3, years=(2024, 2023)):
    """Return ``(year, old_dep)`` for the first year in ``years`` with data.

    Rows whose ``old_dep`` is missing are ignored, so a 2024 row with a NaN
    value falls back to 2023 instead of propagating NaN. Returns ``None`` when
    no listed year has a usable value for ``iso3``.
    """
    for yr in years:
        vals = df.loc[(df['year'] == yr) & (df['iso3'] == iso3), 'old_dep'].dropna()
        if len(vals) > 0:
            return yr, float(vals.iloc[0])
    return None


def _control_means(train, iso3, controls, n_years=5):
    """Mean of each control over the country's last ``n_years`` training rows.

    NaNs are skipped; a control with no observed value in that window is set
    to 0.0 (NOTE(review): 0 is a strong assumption for e.g. inflation).
    """
    window = train[train['iso3'] == iso3].sort_values('year').tail(n_years)
    return {var: float(window[var].mean()) if window[var].notna().any() else 0.0
            for var in controls}


def _simulate_ratings(x_vec, beta_draws, noise_sd, lo=0.0, hi=21.0):
    """Simulated ratings for one regressor vector.

    Each coefficient draw gives a prediction clipped to ``[lo, hi]``. Normal
    noise with s.d. ``noise_sd`` is then added (global NumPy RNG) and the result
    is clipped again.
    """
    preds = np.clip(beta_draws @ x_vec, lo, hi)
    noisy = preds + np.random.normal(0, noise_sd, len(preds))
    return np.clip(noisy, lo, hi)


def _summary_stats(results_df):
    """Calibration and accuracy statistics from unrounded per-country results."""
    err = results_df['median_pred'] - results_df['actual_2024']
    naive_err = results_df['rating_2010'] - results_df['actual_2024']
    rmse = float(np.sqrt(np.mean(err ** 2)))
    naive_rmse = float(np.sqrt(np.mean(naive_err ** 2)))
    pred_safe = results_df['p_safe'] >= 0.5
    actual_safe = results_df['actual_safe'] == 1
    return {
        'n': len(results_df),
        'coverage_80': float(results_df['in_80'].mean()),
        'coverage_50': float(results_df['in_50'].mean()),
        'rmse': rmse,
        'mae': float(np.mean(np.abs(err))),
        'naive_rmse': naive_rmse,
        # Undefined (NaN) when the no-change benchmark is exact for every country.
        'rel_rmse': rmse / naive_rmse if naive_rmse > 0 else float('nan'),
        'direction_acc': float(results_df['direction_correct'].mean()),
        'safe_acc': float((pred_safe == actual_safe).mean()),
    }


def _format_table12(stats, shown):
    """Render Table 12 as Markdown.

    ``stats`` is the dict returned by ``_summary_stats``. ``shown`` has one row
    per country with display-rounded values, already in listing order.
    """
    lines = [
        "### Table 12: Pseudo-OOS Backtest (1990-2010 → 2024)", "",
        "**Panel A: Calibration Summary**", "",
        "| Metric | Value |",
        "|:--|--:|",
        f"| Countries | {stats['n']} |",
        f"| 80% CI coverage | {stats['coverage_80']:.1%} |",
        f"| 50% CI coverage | {stats['coverage_50']:.1%} |",
        f"| RMSE | {stats['rmse']:.2f} |",
        f"| Naive RMSE (no change) | {stats['naive_rmse']:.2f} |",
        f"| Relative RMSE | {stats['rel_rmse']:.2f} |",
        f"| MAE | {stats['mae']:.2f} |",
        f"| Direction accuracy | {stats['direction_acc']:.1%} |",
        f"| Safe status accuracy | {stats['safe_acc']:.1%} |",
        "",
        "**Panel B: Country-Level Results**", "",
        "| Country | Rating 2010 | Actual 2024 | Predicted | [P10, P90] | In 80% CI | P(safe) | Still safe |",
        "|:--|--:|--:|--:|:--|:--|--:|:--|",
    ]
    for _, r in shown.iterrows():
        safe_mark = "Yes" if r['actual_safe'] else "**No**"
        in_ci = "Yes" if r['in_80'] else "**No**"
        lines.append(f"| {r['iso3']} | {r['rating_2010']} | {r['actual_2024']} | "
                     f"{r['median_pred']} | [{r['p10']}, {r['p90']}] | {in_ci} | "
                     f"{r['p_safe']:.2f} | {safe_mark} |")
    lines += [
        "",
        "*Model estimated on 1990-2010 data using OADR spline specification. "
        "Projections use actual 2024 demographics with training-period macro controls. "
        "1,000 Monte Carlo draws from coefficient VCV. "
        "Naive benchmark assumes rating unchanged from 2010.*",
        "",
    ]
    return "\n".join(lines)


def main():
    """Run the pseudo-out-of-sample backtest and write Table 12.

    1. Fit PanelGLS on the 1990-2010 rows of ``cliff_panel.csv``.
    2. For each country in both SAFE_2010 and ACTUAL_2024, use its observed
       2024 ``old_dep`` (falling back to 2023), plus the mean of its last five
       training-year macro controls.
    3. Simulate ratings from coefficient draws plus rating noise.
    4. Compare the simulations with the actual 2024 ratings: interval coverage,
       RMSE vs a no-change benchmark, direction and safe-status accuracy.
    5. Write ``table12_oos_backtest.md`` to TABLES_DIR.
    """
    print("=" * 70)
    print("PHASE 8: Pseudo-OOS Backtest (1990-2010 → 2024)")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    print(f"Loaded: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # ── [1] Estimate on 1990-2010 training window ──
    print("\n[1] Estimating model on 1990-2010 training window ...")
    train = df[(df['year'] >= 1990) & (df['year'] <= 2010)].copy()

    controls = [c for c in ['rgdp_growth', 'inflation', 'kaopen'] if c in train.columns]
    x_vars = ['old_dep', 'oadr_spline_20'] + controls

    est = train.dropna(subset=['rating_numeric'] + x_vars).copy()
    print(f"  Training sample: {est['iso3'].nunique()} countries, {len(est):,} obs")

    model = PanelGLS()
    y = est['rating_numeric'].values
    X = est[x_vars].values
    model.fit(y, X, est['iso3'].values, est['year'].values)
    model.summary(feature_names=x_vars)

    # Ordered [constant, old_dep, spline, *controls], matching x_vec below.
    # NOTE(review): assumes PanelGLS exposes a single pooled intercept.
    beta_full = np.concatenate([[model.constant], model.beta])
    vcv, sigma2 = _ols_vcv(y, X, beta_full)

    # ── [2] Actual 2024 demographics ──
    print("\n[2] Getting actual 2024 demographics ...")
    # Looked up per country in step [3] (2024, else 2023; NaN rows ignored).

    # ── [3] Monte Carlo projection to 2024 ──
    n_sim = 1000
    print(f"\n[3] Monte Carlo projection ({n_sim} draws) ...")
    np.random.seed(42)
    beta_draws = _draw_coefficients(beta_full, vcv, n_sim)
    # Rating noise: half the training residual s.d. — a modelling choice of this
    # script, not an estimated quantity.
    noise_sd = np.sqrt(sigma2) * 0.5
    # Assumes the panel's 'oadr_spline_20' is max(0, old_dep - 0.20) — TODO
    # confirm against the panel-construction code.
    oadr_knot = 0.20

    results_rows = []
    for iso3 in sorted(set(SAFE_2010) & set(ACTUAL_2024)):
        found = _latest_old_dep(df, iso3)
        if found is None:
            print(f"  WARNING: No 2024/2023 old_dep data for {iso3}, skipping")
            continue
        _, oadr_2024 = found
        ctrl = _control_means(train, iso3, controls)

        # Actual 2024 demographics, macro controls frozen at training-period values.
        x_vec = np.array([1.0, oadr_2024, max(0.0, oadr_2024 - oadr_knot)] +
                         [ctrl[c] for c in controls])
        draws = _simulate_ratings(x_vec, beta_draws, noise_sd)

        actual = ACTUAL_2024[iso3]
        rating_2010 = SAFE_2010[iso3]
        p10, p25, median_pred, p75, p90 = np.percentile(draws, [10, 25, 50, 75, 90])

        # A downgrade is "predicted" only when the median is more than half a
        # notch below the 2010 rating.
        pred_down = median_pred < rating_2010 - 0.5
        actual_down = actual < rating_2010

        # Values are kept unrounded; rounding happens only for display.
        results_rows.append({
            'iso3': iso3,
            'rating_2010': rating_2010,
            'actual_2024': actual,
            'median_pred': median_pred,
            'p10': p10,
            'p90': p90,
            'in_80': int(p10 <= actual <= p90),
            'in_50': int(p25 <= actual <= p75),
            'p_safe': float(np.mean(draws >= SAFE_THRESHOLD)),
            'actual_safe': int(actual >= SAFE_THRESHOLD),
            'direction_correct': int(pred_down == actual_down),
            'error': median_pred - actual,
            'oadr_2024_pct': oadr_2024 * 100,
        })

    results_df = pd.DataFrame(results_rows)
    print(f"\n  {len(results_df)} countries projected")
    if results_df.empty:
        print("  No countries could be projected; nothing to summarise.")
        return

    # Display-only rounded copy; all statistics use results_df.
    shown = results_df.round({'median_pred': 1, 'p10': 1, 'p90': 1, 'p_safe': 2,
                              'error': 1, 'oadr_2024_pct': 1})

    # ── [4] Summary statistics ──
    print("\n[4] Summary statistics ...")
    stats = _summary_stats(results_df)
    print(f"  80% coverage: {stats['coverage_80']:.1%} (target: 80%)")
    print(f"  50% coverage: {stats['coverage_50']:.1%} (target: 50%)")
    print(f"  RMSE: {stats['rmse']:.2f} (naive: {stats['naive_rmse']:.2f})")
    print(f"  MAE: {stats['mae']:.2f}")
    print(f"  Direction accuracy: {stats['direction_acc']:.1%}")
    print(f"  Safe status accuracy: {stats['safe_acc']:.1%}")
    print(f"  Rel RMSE vs naive: {stats['rel_rmse']:.2f}")

    # Countries that were safe in 2010 but are below the threshold in 2024.
    lost_safe = shown[(shown['rating_2010'] >= SAFE_THRESHOLD) &
                      (shown['actual_2024'] < SAFE_THRESHOLD)]
    if len(lost_safe) > 0:
        print("\n  Countries that LOST safe status (2010→2024):")
        for _, r in lost_safe.iterrows():
            print(f"    {r['iso3']}: {r['rating_2010']}→{r['actual_2024']}, "
                  f"predicted median={r['median_pred']}, P(safe)={r['p_safe']:.2f}")

    print("\n  All results:")
    print(shown[['iso3', 'rating_2010', 'actual_2024', 'median_pred',
                 'p10', 'p90', 'in_80', 'p_safe', 'direction_correct',
                 'error']].to_string(index=False))

    # ── [5] Write Table 12 ──
    print("\n[5] Writing Table 12 ...")
    out_path = TABLES_DIR / "table12_oos_backtest.md"
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    # Order rows by unrounded error (most over-pessimistic first).
    ordered = shown.loc[results_df['error'].sort_values().index]
    out_path.write_text(_format_table12(stats, ordered), encoding="utf-8")
    print(f"  Saved: {out_path}")

    print("\n" + "=" * 70)
    print("Phase 8 complete.")
    print("=" * 70)


# Run the backtest only when executed as a script, not on import.
if __name__ == "__main__":
    main()
